#include <assert.h>
#include <random>
#include <onnxruntime_cxx_api.h>
#include "cpu_provider_factory.h"
#include <adjust_result.h>

// Generate a random colour.
//
// Returns a vector of three integer channel values, each drawn uniformly
// from the closed range [0, 255].
//
// The engine is seeded from std::random_device only once per thread and then
// reused, instead of constructing a new device and Mersenne Twister on every
// call. thread_local keeps concurrent callers independent, as they were when
// each call owned its own engine.
std::vector<int> generateRandomColor() {
    static thread_local std::mt19937 gen{std::random_device{}()};
    // An integer distribution covers 255 as well; truncating a real value
    // in [0, 1) multiplied by 255 can only ever yield 0..254.
    std::uniform_int_distribution<int> channel(0, 255);

    std::vector<int> color(3);
    for (int& value : color) {
        value = channel(gen);
    }

    return color;
}

int main(int argc, char* argv[]) {
    // // 模型路径，图片路径，缺陷阈值，重叠阈值
    // const char* model_path = "../models/best.onnx";
    // std::string imgPath = "../data/3.bmp";
    // std::string namesPath = "../type.names";
    // float threshold = 0.4;
    // float nms_threshold = 0.4;
    // 自动读取模型路径，图片路径，缺陷阈值，重叠阈值
    std::string model_path_;
    std::string imgPath;
    std::string namesPath;
    float threshold;
    float nms_threshold;
        // 打开配置文件并读取配置
    std::ifstream configFile("../config.txt");
    if (configFile.is_open()) {
        configFile >> model_path_ >> imgPath >> namesPath >> threshold >> nms_threshold;
        configFile.close();

        std::cout << "Model Path: " << model_path_ << std::endl;
        std::cout << "Image Path: " << imgPath << std::endl;
        std::cout << "Names Path: " << namesPath << std::endl;
        std::cout << "Threshold: " << threshold << std::endl;
        std::cout << "NMS Threshold: " << nms_threshold << std::endl;
    } else
        std::cerr << "Failed to open config file." << std::endl;
    const char* model_path = model_path_.c_str();

    // 图片变换
    cv::Mat inputImage = cv::imread(imgPath);
    if (inputImage.empty()) {
        std::cerr << "Failed to load image." << std::endl;
        return 1;
    }
        // 获取图片尺寸
    int y = inputImage.rows;
    int x = inputImage.cols;
        // 图片尺寸变换
    cv::Mat image0 = resizeImage(inputImage, y, x);
        // 图像归一化
    std::vector<float> input_image_ = nchwImage(image0);
    
    // 读取缺陷标志文件
    std::ifstream inputFile(namesPath);
    if (!inputFile.is_open()) {
        std::cerr << "Failed to open the file." << std::endl;
        return 1;
    }
    std::vector<std::string> typeNames;
    std::string line;
    while (std::getline(inputFile, line)) 
        typeNames.push_back(line);
    inputFile.close();
    //     // 打印缺陷标志文件内容
    // std::cout << "Number of elements: " << typeNames.size() << std::endl;
    // for (const std::string &typeName : typeNames) 
    //     std::cout << typeName << std::endl;

    // 缺陷颜色标识随机
    int numColors = typeNames.size();
    std::vector<std::vector<int>> colors;
    for (int i = 0; i < numColors; ++i) 
        colors.push_back(generateRandomColor());
    //     // 打印颜色种类
    // for (const auto &color : colors) 
    //     std::cout << "R: " << color[0] << ", G: " << color[1] << ", B: " << color[2] << std::endl;

    // 模型设置和推理结果
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
        // CPU
    Ort::Session session_{env, model_path, Ort::SessionOptions{nullptr}}; 
        // 模型输入尺寸
    static constexpr const int height_ = 640; //model input height
    static constexpr const int width_ = 640; //model input width
    Ort::Value input_tensor_{nullptr};
    std::array<int64_t, 4> input_shape_{1, 3, height_, width_}; //mode input shape NCHW = 1x3xHxW
        // 模型输出尺寸
    Ort::Value output_tensor_{nullptr};
    std::array<int64_t, 3> output_shape_{1, 9, 8400}; //model output shape,
    std::array<_Float32, 9*8400> results_{};

        // 模型输入输出张量设置
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
    input_tensor_ = Ort::Value::CreateTensor<float>(memory_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());
    output_tensor_ = Ort::Value::CreateTensor<float>(memory_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size());
        // 查看模型输入输出的名称
    const char* input_names[] = {"images"};
    const char* output_names[] = {"output0"};
        // 推理
    session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1);
    float* out = output_tensor_.GetTensorMutableData<float>();

        // 推理结果获取
    int rows = 9;      // 第二维度大小，即行数
    int cols = 8400;   // 第三维度大小，即列数
    std::vector<std::vector<float>> matrix(rows, std::vector<float>(cols));
    for (int row = 0; row < rows; ++row) 
        for (int col = 0; col < cols; ++col) 
            matrix[row][col] = out[row * cols + col];
        // 9，8400数组转置为8400，9
    std::vector<std::vector<float>> tran_matrix = transpose(matrix);
    //     // 显示缺陷筛选结果
    // std::vector<std::vector<float>> num = tran_matrix;
    // for (size_t n = 0; n < num.size(); ++n) {
    //     bool aboveThreshold = false;
    //     for (size_t col = 4; col <= 8; ++col)
    //         if (num[n][col] > threshold) {
    //             aboveThreshold = true;
    //             break;
    //         }
        
    //     if (aboveThreshold) {
    //         std::cout << "Row " << n << ": ";
    //         for (const auto& val : num[n]) 
    //             std::cout << val << " ";
                
    //         std::cout << std::endl;
    //     }
    // }

    // 缺陷还原
    std::vector<std::vector<double>> select_matrix;
    select_matrix = select(tran_matrix, threshold, cols,rows);
        // 缺陷位置信息还原
    select_matrix = return_(select_matrix, y, x);
        // 缺陷位置信息筛选
    select_matrix = nms_(select_matrix, nms_threshold);
    //     // 打印数组的内容
    // for (const auto& row : select_matrix){
    //     for (const auto& value : row) {
    //         std::cout << value << " ";
    //     }
    //     std::cout << std::endl;
    // }
        // 绘制识别框
    cv::Mat outputImage = draw_image(select_matrix, inputImage, typeNames, colors);
    
    // 自定义窗口大小
    int windowWidth = 1200;
    int windowHeight = 900;

    // 调整窗口大小
    cv::namedWindow("Image with Bounding Boxes", cv::WINDOW_NORMAL);
    cv::resizeWindow("Image with Bounding Boxes", windowWidth, windowHeight);
    cv::imshow("Image with Bounding Boxes", outputImage);
    cv::imwrite("marked_image.jpg", outputImage);
    cv::waitKey(0);

    return 0;
}